# coding:utf-8

import sys
from src.constant.file_and_path_constant import FileAndPathConstant

# Make the project root importable. The path suffix is a raw string because it
# contains backslashes (\g, \h, \d, \j) that are not valid escape sequences;
# as a normal literal they trigger a SyntaxWarning/DeprecationWarning on newer
# Python versions. The resulting string value is unchanged.
sys.path.append(FileAndPathConstant.System_Drive + r'\github-repository\hades\dev-project\jormungandr')

from src.manager.oracle_manager import OracleManager
from src.constant.oracle import Oracle
from src.manager.log_manager import LogManager
from src.manager.file_manager import FileManager
from src.manager.knn_manager import KnnManager
from src.constant.file_and_path_constant import FileAndPathConstant
from src.config.knn_config import KNNConfig
from src.config.date_config import DateConfig
from numpy import *
import os
import shutil

# Module-level logger obtained from the project's LogManager, keyed by this
# module's name; used by every method of KnnHandler below.
Logger = LogManager.get_logger(__name__)


class KnnHandler:
    """
    k近邻算法处理器
    """

    def __init__(self):
        """
        Open the Oracle connection and keep a cursor for later queries.

        An ``OracleManager`` is built from the ``Oracle`` constants (username,
        password, url) and connected; the manager and a cursor obtained from
        it are stored on the instance as ``oracle_manager`` and ``cursor``.
        """
        manager = OracleManager(Oracle.Username, Oracle.Password, Oracle.Url)
        self.oracle_manager = manager
        manager.connect()
        self.cursor = manager.get_cursor()

    def close_cursor_and_connect(self):
        """
        Close the cursor and then the database connection.

        The connection is closed even when closing the cursor raises, so a
        failing cursor cannot leak the connection. An exception raised by the
        cursor still propagates to the caller.

        :return: None
        """
        try:
            self.cursor.close()
        finally:
            # Always release the connection, whatever happened to the cursor.
            self.oracle_manager.connect_close()

    def delete_training_data_file_and_testing_data_file(self):
        """
        Delete every feature-vector data file, both training and testing.

        Each entry found under ``FileAndPathConstant.Knn_Vector_File_Path`` is
        treated as a folder; inside it, all entries of the training data
        directory are removed first, then all entries of the testing data
        directory. One log line is written per emptied directory.

        :return: None
        """
        Logger.info('删除所有的特征值数据文件，包括训练数据和测试数据')

        vector_root = FileAndPathConstant.Knn_Vector_File_Path
        for folder_name in os.listdir(vector_root):
            folder_prefix = vector_root + '/' + folder_name
            # Training directory first, testing directory second.
            data_directories = (
                folder_prefix + FileAndPathConstant.Knn_Training_Data_Path,
                folder_prefix + FileAndPathConstant.Knn_Testing_Data_Path,
            )
            for data_directory in data_directories:
                for file_name in os.listdir(data_directory):
                    os.remove(data_directory + '/' + file_name)
                Logger.info('目录【' + data_directory + '】中的文件都被删除了')

    def collect_and_store_data_in_file(self, characteristic_value_number_list):
        """
        Query feature vectors and labels and append them to per-stock files.

        For every requested feature count ``n`` and every distinct stock code
        in ``mdl_stock_analysis``, the rows between
        ``DateConfig.Mdl_Stock_Analysis_Begin_Date`` and
        ``DateConfig.Mdl_Stock_Analysis_End_Date`` are read in ascending date
        order. Each run of ``n + 1`` consecutive rows produces one
        comma-separated line that is appended to the stock's training file::

            date, code, ma_trend, macd_trend, kd_trend, v_0, ..., v_n

        The first five fields are columns 1-5 of the first row of the run and
        ``v_i`` is column 6 of the i-th row of the run; the last value is the
        label. Stocks with fewer than ``n + 1`` rows produce no lines.

        The cursor and the connection are closed when the method finishes,
        including when a query or a file write raises.

        :param characteristic_value_number_list: iterable of non-negative
            integer feature counts; a negative count yields no output
        :return: None
        """
        try:
            for characteristic_value_number in characteristic_value_number_list:
                Logger.info('查询特征向量和标签，并存储在文件中，特征值数量为【' + str(characteristic_value_number) + '】')

                # The folder name and the window size both derive from the count.
                feature_folder_name = str(characteristic_value_number)
                feature_count = int(feature_folder_name)
                if feature_count < 0:
                    # A negative count can never be satisfied; nothing is written.
                    continue
                # Each line holds feature_count values plus one trailing label.
                window_size = feature_count + 1

                # Look up the stock codes.
                self.cursor.execute('select distinct t.code_ from mdl_stock_analysis t')
                for stock_code_tuple in self.cursor.fetchall():
                    stock_code = stock_code_tuple[0]
                    Logger.info('开始收集股票【' + stock_code + '】的数据，特征值数量为【' + str(characteristic_value_number) + '】')

                    # Fetch this stock's rows in ascending date order.
                    self.cursor.execute("select * from mdl_stock_analysis t "
                                        "where t.date_ between to_date('" + DateConfig.Mdl_Stock_Analysis_Begin_Date + "', 'yyyy-mm-dd') and to_date('" + DateConfig.Mdl_Stock_Analysis_End_Date + "', 'yyyy-mm-dd') "
                                        "and t.code_='" + stock_code + "' "
                                        "order by t.date_ asc")
                    rows = self.cursor.fetchall()

                    target_file = (FileAndPathConstant.Knn_Vector_File_Path + '/' + feature_folder_name
                                   + FileAndPathConstant.Knn_Training_Data_Path + '/' + stock_code
                                   + FileAndPathConstant.Knn_Vector_File_Extension_Name)

                    # Slide a fixed-size window over the rows. The range is empty
                    # when there are fewer rows than one full window, and slicing
                    # only the window avoids copying the whole tail per start.
                    for start in range(len(rows) - window_size + 1):
                        first_row = rows[start]
                        up_down_vector_list = [
                            first_row[1].strftime('%Y-%m-%d %H:%M:%S'),  # date
                            first_row[2],                                # stock code
                            str(first_row[3]),                           # ma_trend
                            str(first_row[4]),                           # macd_trend
                            str(first_row[5]),                           # kd_trend
                        ]
                        # Column 6 of every row in the window; the last one is the label.
                        up_down_vector_list.extend(
                            str(row[6]) for row in rows[start:start + window_size])
                        FileManager.write(target_file, 'a', ','.join(up_down_vector_list) + '\n')
        finally:
            # Release the database resources even if collection failed midway.
            self.close_cursor_and_connect()

    def distinguish_training_data_and_testing_data(self):
        """
        Randomly pick testing files from the training data and move them.

        For every folder under ``FileAndPathConstant.Knn_Vector_File_Path``,
        ``round(file_count * KNNConfig.Testing_DataPercentage)`` files are
        chosen at random from the training directory, copied into the testing
        directory and removed from the training directory. The number of
        files moved never exceeds the number of files available.

        :return: None
        """
        # Explicit local import: the module only gets ``random`` through
        # ``from numpy import *`` (i.e. numpy.random), whose ``randint`` excludes
        # the upper bound. Naming numpy here keeps that contract unambiguous.
        import numpy

        Logger.info('从训练数据中挑选测试数据，并将其移动到testing目录中')

        for vector_file_path in os.listdir(FileAndPathConstant.Knn_Vector_File_Path):
            feature_folder = FileAndPathConstant.Knn_Vector_File_Path + '/' + vector_file_path
            # NOTE(review): the other methods join the folder and the data-path
            # constants without an extra '/'; the separator is kept here exactly
            # as before - confirm which form the constants expect.
            training_directory = feature_folder + '/' + FileAndPathConstant.Knn_Training_Data_Path
            testing_directory = feature_folder + '/' + FileAndPathConstant.Knn_Testing_Data_Path
            vector_file_list = os.listdir(training_directory)

            # How many files should become testing data, clamped so that a
            # percentage above 1 cannot ask for more files than exist.
            testing_data_file_number = round(len(vector_file_list) * KNNConfig.Testing_DataPercentage)
            testing_data_file_number = min(testing_data_file_number, len(vector_file_list))

            while testing_data_file_number > 0:
                # Upper bound is exclusive, so every remaining index (including
                # the last one) can be drawn and a single remaining file works.
                random_number = numpy.random.randint(0, len(vector_file_list))
                # Drop the name from the candidates so it cannot be drawn twice.
                file_name = vector_file_list.pop(random_number)
                # Copy to the testing directory, then delete the training copy.
                shutil.copyfile(training_directory + '/' + file_name,
                                testing_directory + '/' + file_name)
                os.remove(training_directory + '/' + file_name)
                testing_data_file_number -= 1

    def do_knn(self, characteristic_value, neighbor_number):
        """
        Load the training and testing sets, run KNN and log the accuracy.

        Both data sets are read with ``create_data_set`` from the training and
        testing directories of the folder named after ``characteristic_value``.
        Every testing vector is classified with ``KnnManager.classify0``
        against the training set, and the percentage of predictions that equal
        the stored label is logged. When the testing set is empty, a message
        is logged instead of dividing by zero.

        :param characteristic_value: feature count; its string form names the
            data folder
        :param neighbor_number: number of neighbours passed to ``classify0``
        :return: None
        """
        # Training data set
        training_group_ndarray, training_label_list = self.create_data_set(
            FileAndPathConstant.Knn_Vector_File_Path + '/' + str(characteristic_value) + FileAndPathConstant.Knn_Training_Data_Path)
        # Testing data set
        testing_group_ndarray, testing_label_list = self.create_data_set(
            FileAndPathConstant.Knn_Vector_File_Path + '/' + str(characteristic_value) + FileAndPathConstant.Knn_Testing_Data_Path)

        Logger.info('开始执行knn算法')
        Logger.info(characteristic_value)
        Logger.info(neighbor_number)

        testing_sample_count = len(testing_group_ndarray)
        if testing_sample_count == 0:
            # Without testing samples the accuracy is undefined.
            Logger.info('测试数据集为空，无法计算正确率')
            return

        # Count how many predictions match the expected label.
        knn_manager = KnnManager()
        correct_number = 0
        for testing_vector, expected_label in zip(testing_group_ndarray, testing_label_list):
            result = knn_manager.classify0(testing_vector, training_group_ndarray, training_label_list, neighbor_number)
            if result == expected_label:
                correct_number += 1
        Logger.info('正确率：' + str(correct_number / testing_sample_count * 100))

    def create_data_set(self, path):
        """
        Read every file in a directory into a feature array and a label list.

        Each non-empty line of each file is split on commas. The first two
        columns are skipped, the last column is converted to ``int`` and used
        as the label, and all columns in between are converted to ``int`` and
        used as the feature vector. Usable for both training and testing data.

        :param path: directory whose files are read
        :return: tuple ``(features, labels)`` where ``features`` is built with
            numpy ``array`` from the list of feature vectors and ``labels`` is
            a plain list of ints, one per line
        """
        Logger.info('读取指定路径下的文件，并将特征值和标签分别存储在numpy.ndarray类型的数组和一维数组中')

        feature_rows = []
        labels = []
        for file_name in os.listdir(path):
            text = FileManager.read(os.path.join(path, file_name), 'r')
            for line in text.split('\n'):
                if not line:
                    # Blank lines (e.g. after the trailing newline) carry no data.
                    continue
                fields = line.split(',')
                # Columns 0 and 1 are skipped; the last column is the label.
                features = [int(field) for field in fields[2:-1]]
                labels.append(int(fields[-1]))
                feature_rows.append(features)
        return array(feature_rows), labels
